#include <iostream>

#include "core/log/Log.h"
#include "engine/ml/MLEngine.h"
//
// Created by neo on 25-5-7.
//
int main(int argc, char *argv[]) {
    std::cout << "Inference demo" << std::endl;

    MLEngine mle;
    if (!mle.Init()) {
        std::cerr << "Failed to initialize engine" << std::endl;
        return EXIT_FAILURE;
    }

    const std::vector<float> input = {
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
    };
    const auto reluInputMat = mle.CreateMatrix(16, 16, input);
    const auto reluOutputMat = mle.CreateMatrix(16, 16);
    // mle.ReLU(reluInputMat, reluOutputMat);
    // mle.Sigmoid(reluInputMat, reluOutputMat);
    // mle.Tanh(reluInputMat, reluOutputMat);
    // mle.Softmax(reluInputMat, reluOutputMat);
    mle.GELU(reluInputMat, reluOutputMat);
    const std::vector<float> mat1 = {
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
    };
    const std::vector<float> mat2 = {
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
        -1, 0, 0.4, 0.3, 0.2, 0.1, -0.7, 0.3, 0.2, 0.1, -0.7, 0.3, 0.3, 0.2, 0.1, -0.7,
    };

    const auto gemmInputMat1 = mle.CreateMatrix(16, 16, mat1);
    const auto gemmInputMat2 = mle.CreateMatrix(16, 16, mat2);
    const auto gemmOutputMat = mle.CreateMatrix(16, 16);

    mle.MatMul(gemmInputMat1, gemmInputMat2, gemmOutputMat);

    mle.Compute();

    reluOutputMat->Print();
    gemmOutputMat->Print();
}
